import numpy as np
import pandas as pd
import statsmodels.api as sm
import math
import sys

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
        """Fit an OLS regression of `outcome` on `pbvarb` with HC1 standard errors.

        Parameters
        ----------
        dataset : DataFrame containing the columns named by `pbvarb` and `outcome`.
        pbvarb : name of the regressor column.
        outcome : name of the dependent-variable column.

        Returns
        -------
        (coef, se, black_aud_rate, non_black_aud_rate) where `coef` and `se`
        are the slope on `pbvarb` and its HC1 standard error,
        `non_black_aud_rate` is the fitted intercept, and `black_aud_rate` is
        intercept + slope (the fitted value at pbvarb == 1).
        """
        model = sm.OLS.from_formula(outcome+' ~ '+pbvarb, data = dataset).fit(cov_type = 'HC1')
        coef = model.params[pbvarb]
        se = model.bse[pbvarb]
        # model.params is a label-indexed Series; integer keys such as
        # params[0] rely on a positional fallback that pandas has deprecated,
        # so the intercept (first parameter) is taken positionally via .iloc.
        non_black_aud_rate = model.params.iloc[0]
        # The second parameter of `outcome ~ pbvarb` is the slope fetched above.
        black_aud_rate = non_black_aud_rate + coef
        return coef, se, black_aud_rate, non_black_aud_rate

## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
        """Compute weighted mean outcomes using `pbvarb` and its complement as weights.

        Returns (est, black_aud_rate, non_black_aud_rate):
        `black_aud_rate` is the mean of `outcome` weighted by `pbvarb`,
        `non_black_aud_rate` is the mean of `outcome` weighted by 1 - `pbvarb`,
        and `est` is the first minus the second.
        """
        weight = dataset[pbvarb]
        response = dataset[outcome]
        complement = 1-weight

        black_aud_rate = (weight*response).sum()/weight.sum()
        non_black_aud_rate = (complement*response).sum()/complement.sum()
        return black_aud_rate - non_black_aud_rate, black_aud_rate, non_black_aud_rate

## Get standard errors
def getSEs(dataset,pbvarb='predicted_prob_black',outcome='audited', seReg = 0.0):
        """Return (seReg, seChen), where seChen is seReg scaled by getSEMultiplier.

        `seReg` is passed through unchanged; `outcome` is accepted but not
        used by this function.
        """
        return seReg, seReg*getSEMultiplier(dataset,pbvarb)

def getSEMultiplier(dataset,pbvarb='predicted_prob_black'):
        """Return sqrt(var(p) / (mean(p) * (1 - mean(p)))) for the column `pbvarb`.

        The variance is whatever pandas' Series.var() returns with its
        default arguments.
        """
        prob = dataset[pbvarb]
        variance = prob.var()
        mean_prob = prob.mean()
        return np.sqrt(variance/(mean_prob*(1-prob.mean())))

# read in 2014 op data
keep_cols = ['taxpayer_id_new', 'taxpayer_id', 'predicted_prob_black', 'aud_no_research_audits', 'isEIC', 'activity_code', 'total_pos_inc_class']
individual2014 = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final_new_eic.csv", usecols=keep_cols)

# Subset to EITC claimants whose total_pos_inc_class is 70 or 71; the two
# groups are reported below under the labels 270 and 271.
# NOTE(review): activity_code is loaded but never used, and the filter is on
# total_pos_inc_class -- confirm that 70/71 is the intended encoding of
# activity codes 270/271.
individual_eic = individual2014[(individual2014.isEIC == 1) & (individual2014.total_pos_inc_class.isin([70, 71]))]
individual_eic_270 = individual_eic[individual_eic.total_pos_inc_class == 70]
individual_eic_271 = individual_eic[individual_eic.total_pos_inc_class == 71]

# Audited claimants (aud_no_research_audits == 1), overall and per group.
eic_aud = individual_eic[individual_eic.aud_no_research_audits == 1]
eic_aud_270 = eic_aud[eic_aud.total_pos_inc_class == 70]
eic_aud_271 = eic_aud[eic_aud.total_pos_inc_class == 71]

# Estimates are computed before the report file is opened, so a failure here
# does not leave a truncated report behind.
# lin_* tuples: (coef, se, black rate, non-black rate)
# prob_* tuples: (disparity, black rate, non-black rate)
lin_270 = regEstimate(individual_eic_270, 'predicted_prob_black', 'aud_no_research_audits')
prob_270 = chenEstimate(individual_eic_270, 'predicted_prob_black', 'aud_no_research_audits')

lin_271 = regEstimate(individual_eic_271, 'predicted_prob_black', 'aud_no_research_audits')
prob_271 = chenEstimate(individual_eic_271, 'predicted_prob_black', 'aud_no_research_audits')


def _pct(value):
        """Express a proportion as a percentage rounded to two decimals."""
        return round(100*value, 2)


def _print_pair(out, heading, value_270, value_271):
        """Write a heading followed by the 270 and 271 values, one per line."""
        print(heading, file=out)
        print(value_270, file=out)
        print(value_271, file=out)


# write out results
# The report is written through an explicit file handle instead of rebinding
# sys.stdout, so the file is always closed and stdout is never left redirected
# if something raises part-way through.
with open("/REDACTED/eitc_audit_stats_by_ac.txt", "w") as report:
        _print_pair(report, "Share of EITC Claimants (270, 271)",
                    round(len(individual_eic_270) / len(individual_eic), 2),
                    round(len(individual_eic_271) / len(individual_eic), 2))

        _print_pair(report, "\nShare of Audits of EITC Claimants (270, 271)",
                    round(len(eic_aud_270) / len(eic_aud), 2),
                    round(len(eic_aud_271) / len(eic_aud), 2))

        _print_pair(report, "\nAudit Rate (270, 271)",
                    _pct(individual_eic_270.aud_no_research_audits.mean()),
                    _pct(individual_eic_271.aud_no_research_audits.mean()))

        _print_pair(report, '\nDisparity (Probabilistic)(270, 271)',
                    _pct(prob_270[0]), _pct(prob_271[0]))

        _print_pair(report, '\nDisparity (Linear)(270, 271)',
                    _pct(lin_270[0]), _pct(lin_271[0]))

        _print_pair(report, '\nBlack Audit Rate (Probabilistic)(270, 271)',
                    _pct(prob_270[1]), _pct(prob_271[1]))

        _print_pair(report, '\nBlack Audit Rate (Linear)(270, 271)',
                    _pct(lin_270[2]), _pct(lin_271[2]))

        _print_pair(report, '\nNon-Black Audit Rate (Probabilistic)(270, 271)',
                    _pct(prob_270[2]), _pct(prob_271[2]))

        _print_pair(report, '\nNon-Black Audit Rate (Linear)(270, 271)',
                    _pct(lin_270[3]), _pct(lin_271[3]))

        _print_pair(report, '\nN (270, 271)',
                    len(individual_eic_270), len(individual_eic_271))
